#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Milvus集合信息查询脚本
用于查询Milvus collection的schema和节点总数
"""
import sys
import os
from pathlib import Path

# Make the project root importable: the directory reached by four `.parent`
# hops from this file is appended to sys.path.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.append(str(project_root))

# Import pymilvus. The ORM-style API (connections / utility / Collection) is
# required; if it cannot be imported the script prints an install hint and
# exits with status 1. MilvusClient is optional, and whether it could be
# imported is recorded in `have_milvus_client` for connect_to_milvus().
try:
    from pymilvus import connections, utility, Collection
    # Newer API - MilvusClient (may be absent in older pymilvus releases)
    try:
        from pymilvus import MilvusClient
        have_milvus_client = True
    except ImportError:
        have_milvus_client = False
except ImportError as e:
    print(f"错误: 导入pymilvus失败 - {e}")
    print("请执行: pip install -r backend/requirements.txt")
    sys.exit(1)

# Put the backend directory (three `.parent` hops from this file) at the
# FRONT of sys.path so the `app` package imported below resolves from it.
# Per the original note, this path hack is intended for test scripts only.
backend_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_dir))

# Import the shared configuration manager; exit with status 1 if missing.
try:
    from app.utils.config_manager import config
except ImportError as e:
    print(f"错误: 导入配置管理器失败 - {e}")
    print("请确保config_manager.py文件存在于正确的路径")
    sys.exit(1)


def connect_to_milvus():
    """Connect to the Milvus server configured in ``config.milvus_config``.

    When ``MilvusClient`` could be imported it is tried first. If it is
    unavailable or raises, the function falls back to the ORM-style
    ``connections.connect`` under the alias ``"default"``. The token is
    forwarded only when one is configured.

    Returns:
        tuple: ``(connected, client)``. ``client`` is the ``MilvusClient``
        instance when that path succeeded, otherwise ``None``.
    """
    settings = config.milvus_config
    uri = settings.get("uri")
    token = settings.get("token")

    if not uri:
        print("错误: 未配置Milvus URI，请检查配置文件")
        return False, None

    # Shared keyword arguments for both connection styles; an empty/None
    # token means "no authentication" and is therefore left out entirely.
    conn_kwargs = {"uri": uri}
    if token:
        conn_kwargs["token"] = token

    if have_milvus_client:
        try:
            print("尝试使用 MilvusClient 连接...")
            milvus_client = MilvusClient(**conn_kwargs)
        except Exception as err:
            print(f"使用 MilvusClient 连接失败: {err}")
            print("尝试使用传统连接方式...")
        else:
            print(f"成功使用 MilvusClient 连接到Milvus服务器: {uri}")
            return True, milvus_client

    # Legacy ORM connection registered under the "default" alias.
    try:
        connections.connect("default", **conn_kwargs)
    except Exception as err:
        print(f"连接Milvus服务器失败: {err}")
        return False, None

    print(f"成功连接到Milvus服务器: {uri}")
    return True, None


def get_all_collections(client=None):
    """Return the names of all collections on the connected Milvus server.

    Args:
        client: Optional MilvusClient. When falsy, the ORM-level
            ``utility.list_collections()`` is used instead.

    Returns:
        The collection name list, or ``[]`` if listing failed (the error
        is printed).
    """
    try:
        if not client:
            # Legacy ORM API on the "default" connection.
            return utility.list_collections()
        return client.list_collections()
    except Exception as err:
        print(f"获取集合列表失败: {err}")
        return []


def _field_type_label(field_type):
    """Return a readable label for a field type (enum member name or as-is).

    If the type is an enum member (as pymilvus ``DataType`` values are), its
    ``name`` is used: an ``IntEnum`` member renders as its bare integer code
    inside an f-string, which is unreadable. Plain strings pass through.
    """
    return getattr(field_type, "name", field_type)


def _client_dim_info(field):
    """Return the ``", 维度=..."`` suffix for vector fields, ``""`` otherwise."""
    type_label = str(_field_type_label(field.get("type"))).upper()
    # Substring match covers FLOAT_VECTOR, BINARY_VECTOR, ... and a bare
    # "VECTOR"; comparing against exact strings never matched enum values.
    if "VECTOR" not in type_label:
        return ""

    # Different client versions store the dimension in different places.
    params = field.get("params")
    type_params = field.get("type_params")
    if isinstance(params, dict) and "dim" in params:
        dim = params["dim"]
    elif "dimension" in field:
        dim = field["dimension"]
    elif isinstance(type_params, dict) and "dim" in type_params:
        dim = type_params["dim"]
    else:
        dim = None
    return f", 维度={dim}" if dim else ", 维度=未知"


def _client_row_count(client, collection_name):
    """Best-effort entity count for a collection; never raises.

    Strategies, in order: a ``count(*)`` query, ``get_collection_stats``,
    ``count_entities``, the ``row_count``/``num_entities`` keys of
    ``describe_collection``, and finally a probe ``get``. Returns the count
    or a descriptive placeholder string.
    """
    try:
        # count(*) needs no limit argument.
        query_result = client.query(
            collection_name=collection_name,
            filter="",
            output_fields=["count(*)"],
        )
        if query_result:
            return query_result[0].get("count(*)", "未知")
        return "未知"
    except Exception as e:
        print(f"查询行数失败: {e}")

    # Fallbacks below are optional APIs that may not exist in every client
    # version; a failure simply moves on to the next strategy.
    try:
        stats = client.get_collection_stats(collection_name=collection_name)
        if isinstance(stats, dict) and "row_count" in stats:
            return stats["row_count"]
    except Exception:
        pass

    try:
        return client.count_entities(collection_name=collection_name)
    except Exception:
        pass

    try:
        collection_info = client.describe_collection(collection_name=collection_name)
        for key in ("row_count", "num_entities"):
            if key in collection_info:
                return collection_info[key]
    except Exception:
        return "无法获取"

    # Last resort: probe for a single record to tell "empty" from "has data".
    try:
        data = client.get(collection_name=collection_name, ids=["0"], limit=1)
    except Exception:
        return "未知"
    return "至少有数据" if data else "可能为空"


def _client_fields(client, collection_name):
    """Return the field dicts of a collection, or ``[]`` on failure."""
    try:
        description = client.describe_collection(collection_name=collection_name)
        # The field list may be nested under "schema" or sit at the top level
        # depending on the client version; accept either.
        nested = description.get("schema") or {}
        fields = nested.get("fields") if isinstance(nested, dict) else None
        return fields or description.get("fields") or []
    except Exception as e:
        print(f"获取集合schema失败: {e}")
        return []


def _client_index_info(client, collection_name, field_name):
    """Return index details for one field (usually a dict), or ``None``."""
    # Preferred: look up the index names bound to this field, then describe
    # them. Index names need not equal field names.
    list_indexes = getattr(client, "list_indexes", None)
    if callable(list_indexes):
        try:
            for index_name in list_indexes(collection_name=collection_name,
                                           field_name=field_name) or []:
                info = client.describe_index(collection_name=collection_name,
                                             index_name=index_name)
                if info:
                    return info
        except Exception:
            pass  # fall through to the version-specific fallbacks below

    # Fallbacks used by various client versions: pass the field name where
    # an index name is expected (works when the index uses the default name).
    for method_name in ("describe_index", "get_index_info", "get_index_state"):
        method = getattr(client, method_name, None)
        if method is None:
            continue
        try:
            info = method(collection_name, field_name)
        except Exception:
            continue
        if info:
            return info
    return None


def _print_client_indexes(client, collection_name, fields):
    """Print one line per indexed field, preceded by a single header."""
    header_printed = False
    for field in fields:
        field_name = field.get("name", "")
        if not field_name or field_name == "count":
            continue

        index_info = _client_index_info(client, collection_name, field_name)
        if not index_info:
            continue

        if not header_printed:
            print("索引信息:")
            header_printed = True

        # Index info formats differ between versions; read what is there.
        index_type = "未知"
        metric_type = "未知"
        if isinstance(index_info, dict):
            index_type = index_info.get("index_type", index_info.get("type", "未知"))
            params = index_info.get("params") or {}
            nested_metric = params.get("metric_type", "未知") if isinstance(params, dict) else "未知"
            metric_type = index_info.get("metric_type", nested_metric)

        print(f"  - 字段: {field_name}, 索引类型: {index_type}, 距离度量: {metric_type}")

    if not header_printed:
        print("未找到索引信息或无法使用当前API获取")


def get_collection_info_via_client(client, collection_name):
    """Print entity count, fields and indexes of a collection via MilvusClient.

    Args:
        client: A connected ``MilvusClient`` instance.
        collection_name: Name of the collection to describe.

    Returns:
        bool: True if the information was printed. False if the collection
        does not exist or an unexpected error occurred (the error is printed).
    """
    try:
        if collection_name not in client.list_collections():
            print(f"集合 {collection_name} 不存在")
            return False

        row_count = _client_row_count(client, collection_name)
        fields = _client_fields(client, collection_name)

        print(f"\n集合名称: {collection_name}")
        print(f"实体数量: {row_count}")
        print("字段信息:")
        for field in fields:
            field_name = field.get("name", "未知")
            field_type = _field_type_label(field.get("type", "未知"))
            print(f"  - {field_name}: 类型={field_type}{_client_dim_info(field)}")

        # There is no single cross-version index API, so this is best effort.
        try:
            _print_client_indexes(client, collection_name, fields)
        except Exception as e:
            print(f"获取索引信息失败: {e}")
            print("无法使用当前API版本获取索引信息")

        return True
    except Exception as e:
        print(f"使用MilvusClient获取集合 {collection_name} 信息失败: {e}")
        return False


def get_collection_info(collection_name, client=None):
    """Print entity count, fields and indexes of one collection.

    Delegates to get_collection_info_via_client when a MilvusClient is
    given. Otherwise uses the ORM ``Collection`` API on the ``"default"``
    connection.

    Args:
        collection_name: Name of the collection to describe.
        client: Optional MilvusClient instance (as returned by
            connect_to_milvus).

    Returns:
        bool: True if the information was printed, False on failure (the
        error is printed).
    """
    # Prefer the MilvusClient path when a client instance is available.
    if client:
        return get_collection_info_via_client(client, collection_name)

    try:
        collection = Collection(collection_name)

        # Entity count via num_entities (per the original note, newer APIs no
        # longer use get_statistics).
        try:
            row_count = collection.num_entities
        except Exception as e:
            print(f"获取实体数量失败: {e}")
            row_count = "未知"

        schema = collection.schema

        print(f"\n集合名称: {collection_name}")
        print(f"实体数量: {row_count}")
        print("字段信息:")

        for field in schema.fields:
            # Show the enum member name; an IntEnum in an f-string renders as
            # its bare integer code.
            type_label = getattr(field.dtype, "name", field.dtype)
            # FieldSchema.params is a dict (e.g. {"dim": 768}), so the
            # dimension must be read by key; attribute access is kept as a
            # fallback for objects that expose it that way.
            params = getattr(field, "params", None)
            if isinstance(params, dict):
                dim = params.get("dim")
            else:
                dim = getattr(params, "dim", None)
            dim_info = f", 维度={dim}" if dim else ""
            print(f"  - {field.name}: 类型={type_label}{dim_info}")

        # Collection.indexes lists the indexes that actually exist together
        # with their fields, regardless of the index name. Per-field
        # has_index()/index() calls only look up the default index name.
        index_info = []
        try:
            for index in collection.indexes:
                if index.field_name == "count":
                    continue
                index_params = index.params or {}
                index_info.append({
                    "field_name": index.field_name,
                    "index_type": index_params.get("index_type", "未知"),
                    "metric_type": index_params.get("metric_type", "未知"),
                })
        except Exception as e:
            print(f"获取索引信息失败: {e}")

        if index_info:
            print("索引信息:")
            for idx in index_info:
                print(f"  - 字段: {idx['field_name']}, 索引类型: {idx['index_type']}, 距离度量: {idx['metric_type']}")
        else:
            print("未创建索引")

        return True
    except Exception as e:
        print(f"获取集合 {collection_name} 信息失败: {e}")
        return False


def _disconnect_milvus(client):
    """Release the connection opened by connect_to_milvus (best effort).

    A MilvusClient instance is closed if it exposes ``close()``. Otherwise
    the ORM connection registered under the alias ``"default"`` is dropped.
    Errors are printed, not raised, so cleanup never masks the report.
    """
    try:
        if client is not None:
            close = getattr(client, "close", None)
            if callable(close):
                close()
        else:
            connections.disconnect("default")
    except Exception as e:
        print(f"关闭Milvus连接失败: {e}")


def main():
    """List every Milvus collection and print its details.

    Reports which config file is in use, connects to Milvus (preferring
    MilvusClient), prints the collection list with the configured default
    collection marked, and then prints each collection's details. The
    connection is released before returning.
    """
    try:
        config_path = config.get_config_file_path()
        print(f"已加载配置文件: {config_path}")
    except Exception as e:
        print(f"获取配置文件路径失败: {e}")

    # Configured default collection, only used to mark it in the listing.
    default_collection = config.milvus_config.get("collection_name", "")

    connected, client = connect_to_milvus()
    if not connected:
        return

    try:
        all_collections = get_all_collections(client)
        if not all_collections:
            print("Milvus中没有找到任何集合")
            return

        print(f"Milvus中共有 {len(all_collections)} 个集合:")
        for i, name in enumerate(all_collections, 1):
            if name == default_collection:
                print(f"{i}. {name} (默认)")
            else:
                print(f"{i}. {name}")

        print("\n===== 集合详细信息 =====")
        for collection_name in all_collections:
            get_collection_info(collection_name, client)
            print("-" * 50)
    finally:
        _disconnect_milvus(client)


if __name__ == "__main__":
    # Ctrl+C exits quietly. Any other error is reported with a traceback and
    # a non-zero exit status, so shells/CI can detect the failure.
    try:
        main()
    except KeyboardInterrupt:
        print("\n操作已取消")
    except Exception as e:
        print(f"发生未处理的错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)